Fix aten_stft ONNX spec violations#2943
Merged
Merged
Conversation
…nk-3 signal) The ONNX STFT op requires: - a rank-3 signal of shape [batch, signal_length, 1], and - frame_step and frame_length to share the same (scalar) type. aten_stft previously passed a rank-1 frame_step (Reshape of hop_length) while frame_length (n_fft) was a rank-0 scalar, and only reshaped the signal up to rank 2. This produced STFT nodes that violate the spec. Fix by passing hop_length directly as the scalar frame_step and adding a trailing [1] dimension to the signal so it is rank 3 for both rank-1 and rank-2 torch.stft inputs. Adds an e2e regression test asserting the emitted STFT node is spec-compliant for rank-1 and rank-2 inputs. Fixes #2942 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2943 +/- ##
==========================================
+ Coverage 72.61% 72.64% +0.02%
==========================================
Files 259 259
Lines 31597 31766 +169
Branches 2973 3007 +34
==========================================
+ Hits 22945 23075 +130
- Misses 7643 7672 +29
- Partials 1009 1019 +10 ☔ View full report in Codecov by Harness. |
Contributor
There was a problem hiding this comment.
Pull request overview
Fixes ONNX STFT operator spec violations in the aten::stft lowering by ensuring the exported node receives spec-compliant input ranks and scalar parameters, and adds a regression test to prevent reintroduction.
Changes:
- Make the
STFTsignalinput rank-3 by appending a trailingUnsqueeze(-1)for both rank-1 and rank-2 Torch inputs. - Pass
hop_lengthdirectly to ONNXSTFTasframe_step(removing the prior reshape that produced a rank-1 tensor). - Add an end-to-end regression test that inspects the emitted ONNX
STFTnode inputs for rank/spec compliance.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
onnxscript/function_libs/torch_lib/ops/core.py |
Adjusts aten_stft preprocessing and STFT invocation so signal is rank-3 and frame_step/frame_length are scalar inputs per the ONNX spec. |
tests/function_libs/torch_lib/e2e_ops_tests.py |
Adds a regression test validating the exported ONNX graph contains a spec-compliant STFT node for both rank-1 and rank-2 inputs. |
gramalingam
reviewed
Jun 22, 2026
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
xadupre
approved these changes
Jun 23, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
aten_stftemitted an ONNXSTFTnode that violated the operator spec in two ways (reported in #2942 by @bas-aarts):frame_step/frame_lengthtype mismatch.frame_stepwas built viaop.Reshape(hop_length, [1]), producing a rank-1 tensor, whileframe_length(n_fft) is a rank-0 scalar. The spec (T2for both) requires them to share the same type/rank.STFTsignal must be rank 3 ([batch, signal_length, 1]for real input).torch.stftaccepts rank-1 or rank-2 signals, and the code only unsqueezed a rank-1 signal up to rank 2, leaving the signal one dimension short of the spec.Fix
hop_lengthdirectly asframe_step(a rank-0 scalar, matchingn_fft) and remove theframe_step_const = op.Reshape(...)line.hop_lengthis always a Python int in this trace-only function (input arg orn_fft // 4), so it is emitted as a rank-0 scalar exactly liken_fft.self = op.Unsqueeze(self, [-1])after the existing batch-dim handling so the signal becomes rank 3 for both rank-1 and rank-2 inputs.Rank trace (verified by inspecting the emitted STFT node)
Before → after, STFT inputs
[signal, frame_step, window, frame_length]:End-to-end shapes (fixed): rank-1 input
[L]→ batch unsqueeze[1, L]→ trailing unsqueeze[1, L, 1]→ STFT[1, frames, bins, 2]→ Transpose[1, bins, frames, 2]→ squeeze batch[bins, frames, 2]. Rank-2 input[B, L]→[B, L, 1]→[B, frames, bins, 2]→[B, bins, frames, 2]. Both matchtorch.stft's real output. Thenormalizedandonesidedpaths are unaffected (only dtype/scaling, not rank).Tests
ops.aten.stft) and_testing.assert_onnx_programvalue-comparison harnesses pass both before and after — they don't strictly enforce the STFT rank, which is why this spec bug went unnoticed.test_aten_stft_emits_spec_compliant_node(parameterized for rank-1 and rank-2 inputs) intests/function_libs/torch_lib/e2e_ops_tests.py, which asserts the emitted STFT node'ssignalis rank 3 andframe_step/frame_lengthare both scalar. This test fails on the old code (2 != 3) and passes with the fix.Ran:
pytest tests/function_libs/torch_lib/ops_test.py -k stft→ 2 passed, 1 skipped, 2 xfailedpytest tests/function_libs/torch_lib/e2e_ops_tests.py -k stft→ 6 passedlintrunner -a→ no lint issuesFixes #2942